
Multilabel integration into task chains as attribute extraction #902

Merged
merged 5 commits into main from multilabel-taskchain-integration
Sep 24, 2024

Conversation

@DhruvaBansal00 (Contributor)

Pull Request Summary

Description

A summary of the change, including relevant motivation and context. This could include links to any docs, Slack threads, GitHub issues, or other artifacts.

Type of change

  • Bug fix (change which fixes an issue)
  • New feature (change which adds functionality)
  • This change requires a documentation update

Tests

Locally


@nihit (Contributor) left a comment:

lgtm otherwise

@@ -76,6 +77,47 @@ def logprob_average(
count += 1
return logprob_cumulative ** (1.0 / count) if count > 0 else 0

def _logprob_average_per_label(
Contributor:

@rajasbansal for review

@DhruvaBansal00 (Contributor Author)

> do we need to do something similar to this https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/tasks/base.py#L224 ?

Fortunately, no. We just return a semicolon-separated list as the label for multilabel attributes, and the product takes care of splitting it when displaying, just like before. Confidence computation does have to do some splitting, however, to get a confidence value for each key.
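
A rough sketch of the flow described above (function and variable names are hypothetical, not the actual autolabel API): the multilabel value stays a single semicolon-separated string end to end, and only the confidence step splits it to score each label.

```python
# Hypothetical sketch: the multilabel attribute value is one
# semicolon-separated string; confidence computation splits it so each
# individual label can get its own score.
def split_multilabel(value: str) -> list[str]:
    return [label.strip() for label in value.split(";") if label.strip()]

def per_label_confidence(value: str, logprob_per_label: dict[str, float]) -> dict[str, float]:
    # Labels with no matching logprob entry fall back to 0.0 confidence.
    return {label: logprob_per_label.get(label, 0.0) for label in split_multilabel(value)}

# e.g. per_label_confidence("Red; Blue", {"Red": 0.91, "Blue": 0.84})
# -> {"Red": 0.91, "Blue": 0.84}
```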

for conf_label_candiate in conf_label_keys:
    closest_match, closest_match_score = None, 0
    for label in logprob_per_label:
        longest_substring = difflib.SequenceMatcher(
Contributor:

nit: comment on what this function does/what the arguments mean here

Contributor Author:

added comments!

        longest_substring = difflib.SequenceMatcher(
            None, label, conf_label_candiate
        ).find_longest_match(0, len(label), 0, len(conf_label_candiate))
        if longest_substring.size > closest_match_score:
Contributor:

What if one label is contained within another label, e.g. if the labels are LessRed, MoreRed, and Red? Will that lead to an issue here?

Contributor Author:

Changed the metric to the proportion of the larger string that overlaps, instead of the raw number of matching characters. I think that should take care of this case.
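
A minimal sketch of that normalized metric (an illustration, not the exact code in the PR): the longest common substring is divided by the length of the longer of the two strings, so an exact match always beats a containment match.

```python
import difflib

# Illustrative only: score each known label against a candidate key by the
# longest common substring as a proportion of the longer string, rather
# than by raw character count.
def closest_label(candidate: str, labels: list[str]) -> tuple[str | None, float]:
    best_label, best_score = None, 0.0
    for label in labels:
        match = difflib.SequenceMatcher(None, label, candidate).find_longest_match(
            0, len(label), 0, len(candidate)
        )
        score = match.size / max(len(label), len(candidate))
        if score > best_score:
            best_label, best_score = label, score
    return best_label, best_score

# With labels ["LessRed", "MoreRed", "Red"], the candidate "Red" scores
# 1.0 against "Red" but only 3/7 against "LessRed" or "MoreRed", whereas a
# raw character count (3 in all three cases) would produce a tie.
```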

Comment on lines -94 to -108
# Remove all characters before a '\n' character in the logprobs, as
# this is parsed during the generation process
# In this case if input logprobs is
# [{"xx\nc": -1.2},{"Ab\n": -1.2}, {"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
# The output logprobs would be [{"Abc": -1.2}, {";": -1.3}, {"B": -1.4}, {"cd": -1.6}, {";": -1.5}, {"C": -1.4}]
for i in range(len(logprobs) - 1, -1, -1):
    cur_key = list(logprobs[i].keys())[0]
    if "\n" in cur_key:
        new_key = cur_key.split("\n")[-1].strip()
        if not new_key:
            logprobs = logprobs[i + 1 :]
        else:
            logprobs[i] = {new_key: logprobs[i][cur_key]}
            logprobs = logprobs[i:]
        break
Contributor:

This was needed because of a bug where sometimes the model would output a lot of characters, like \n, or some explanation followed by a \n, before giving the label. We handle this in parse_llm_response by removing everything before the last \n. Is the expectation that this will never happen now because of guided generation?
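
For context, a minimal sketch of the parse-side cleanup being referred to (hypothetical, not the actual parse_llm_response implementation): everything before the last newline is dropped so only the final line is treated as the label.

```python
# Hypothetical illustration of "removing everything before the last \n":
raw_output = "Let me think about the colors here...\nRed;Blue"
label_text = raw_output.split("\n")[-1].strip()  # -> "Red;Blue"
```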

Contributor Author:

Yep, with guided generation this shouldn't happen now.

@@ -162,7 +169,7 @@ def logprob_average_per_key(

# Find the locations of each key in the logprobs as indices
# into the logprobs list
locations = []
Contributor:

comment on why we need this?

Contributor Author:

This is just an implementation choice: it lets us avoid explicitly setting logprob_per_key[locations[-1][2]] at the end.
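
A hedged sketch of that pattern (names and tuple layout are assumptions, not the PR's actual data structures): appending a sentinel location at the end means every key's span can be taken as a slice between consecutive locations, including the last one, with no special case after the loop.

```python
# Hypothetical sketch: locate where each key starts in the token list and
# append a sentinel at the end, so the final key needs no separate
# assignment after the loop.
def tokens_per_key(keys: list[str], tokens: list[str]) -> dict[str, list[str]]:
    locations = []
    for key in keys:
        for i, tok in enumerate(tokens):
            if key in tok:
                locations.append((key, i))
                break
    locations.append((None, len(tokens)))  # sentinel closes the last span
    return {
        key: tokens[start : locations[j + 1][1]]
        for j, (key, start) in enumerate(locations[:-1])
    }
```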

f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list"
)
llm_label.pop(attribute["name"], None)
if attr_type == TaskType.CLASSIFICATION:
Contributor:

Also check that every label we get for a multilabel attribute is one of the allowed options, to verify the guidelines were followed.
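
A small sketch of the check being suggested here (a hypothetical helper, not autolabel code): split the semicolon-separated value and report any label that is not among the attribute's allowed options.

```python
# Hypothetical helper: return the labels in a multilabel value that are
# not in the allowed options for the attribute.
def unknown_labels(value: str, options: set[str]) -> list[str]:
    labels = [label.strip() for label in value.split(";") if label.strip()]
    return [label for label in labels if label not in options]

# e.g. unknown_labels("Red;Purple", {"Red", "Blue"}) -> ["Purple"]
```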

Contributor Author:

We shouldn't need this either moving forward with guided generation. Will be removing classification as well soon!

@rajasbansal (Contributor) left a comment:

lgtm if #902 (comment) is not an issue

@nihit (Contributor) commented on Sep 22, 2024:

thanks @DhruvaBansal00 lgtm from my end

@DhruvaBansal00 merged commit 7fb125b into main on Sep 24, 2024
2 checks passed
@DhruvaBansal00 deleted the multilabel-taskchain-integration branch on September 24, 2024 02:23